# Source code for hysop.backend.device.codegen.symbolic.expr

# Copyright (c) HySoP 2011-2024
# This file is part of HySoP software.
# See ""
# for further info.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from hysop.backend.device.opencl.opencl_types import (
    basetype as cl_basetype,
    components as cl_components,
    vtype as cl_vtype,
from hysop.backend.device.codegen.base.variables import ctype_to_dtype
import sympy as sm
from hysop.symbolic import Symbol, Expr
from hysop.symbolic.array import OpenClSymbolicBuffer, OpenClSymbolicNdBuffer
from import check_instance, first_not_None, to_tuple, to_list
from import is_fp, is_signed, is_unsigned, is_integer, is_complex

from packaging import version

if version.parse(sm.__version__) >= version.parse("1.7"):
    # sympy >= 1.7 exposes the C printers from ``sympy.printing.c``.
    from sympy.printing.c import C99CodePrinter
else:
    # Older sympy only provides the legacy ``sympy.printing.ccode`` module.
    from sympy.printing.ccode import C99CodePrinter

# Sentinel returned by ``_ccode`` hooks that write their statement straight
# into ``printer.codegen`` instead of returning inline code;
# ``OpenClPrinter.doprint`` compares the printed result against it.
InstructionTermination: str = ""

class TypedI:
    """Mixin giving symbolic nodes an OpenCL C type.

    Subclasses set ``self.ctype`` (an OpenCL type name string) after
    construction; the properties below derive the base type, component
    count, dtype and numeric-kind flags from it through the
    ``cl_basetype`` / ``cl_components`` / ``ctype_to_dtype`` helpers.

    The optional ``positive`` keyword is removed from the constructor
    arguments before they reach the next base class and is kept as a
    positivity hint for :attr:`is_positive`.
    """

    def __new__(cls, *args, **kwds):
        positive = kwds.pop("positive", None)
        obj = super().__new__(cls, *args, **kwds)
        # ``is_positive`` reads the private ``_positive`` attribute;
        # ``positive`` is kept as a public alias for existing callers.
        obj._positive = positive
        obj.positive = positive
        return obj

    @classmethod
    def vtype(cls, btype, n):
        """Return ``cl_vtype(btype, n)``.

        Presumably the OpenCL vector type name for base type ``btype`` with
        ``n`` components (e.g. ``float4``) -- confirm against ``cl_vtype``.
        """
        return cl_vtype(btype, n)

    @property
    def btype(self):
        """Base (scalar) type of ``ctype`` as returned by ``cl_basetype``."""
        return cl_basetype(self.ctype)

    @property
    def basetype(self):
        """Alias of :attr:`btype`."""
        return self.btype

    @property
    def components(self):
        """Component count of ``ctype`` as returned by ``cl_components``."""
        return cl_components(self.ctype)

    @property
    def dtype(self):
        """dtype returned by ``ctype_to_dtype`` for the base type."""
        return ctype_to_dtype(self.btype)

    @property
    def is_signed(self):
        """Whether :attr:`dtype` is signed (per ``is_signed``)."""
        return is_signed(self.dtype)

    @property
    def is_unsigned(self):
        """Whether :attr:`dtype` is unsigned (per ``is_unsigned``)."""
        return is_unsigned(self.dtype)

    @property
    def is_integer(self):
        """Whether :attr:`dtype` is an integer type (per ``is_integer``)."""
        return is_integer(self.dtype)

    @property
    def is_fp(self):
        """Whether :attr:`dtype` is a floating point type (per ``is_fp``)."""
        return is_fp(self.dtype)

    @property
    def is_complex(self):
        raise NotImplementedError()

    @property
    def is_positive(self):
        """Explicit ``positive`` hint if one was given, else :attr:`is_unsigned`.

        Relies on ``first_not_None`` returning its first non-None argument.
        """
        # getattr: tolerate access before ``__new__`` has stored the hint.
        return first_not_None(getattr(self, "_positive", None), self.is_unsigned)
class TypedSymbol(TypedI, Symbol):
    """Symbol tagged with an OpenCL C type name stored in ``ctype``.

    Keyword arguments (including the optional ``positive`` hint consumed by
    ``TypedI``) are forwarded to the parent constructors.
    """

    def __new__(cls, ctype, **kwds):
        instance = super().__new__(cls, **kwds)
        instance.ctype = ctype
        return instance
class TypedExpr(TypedI, Expr):
    """Symbolic expression node tagged with an OpenCL C type name.

    ``ctype`` is passed to the parent constructor as the first argument
    (followed by ``args``) and is also stored as the ``ctype`` attribute.
    """

    def __new__(cls, ctype, *args):
        # Try to build the node unevaluated first; fall back to the plain
        # constructor when the parent rejects the ``evaluate`` keyword.
        try:
            node = super().__new__(cls, ctype, *args, evaluate=False)
        except TypeError:
            node = super().__new__(cls, ctype, *args)
        # ctype must be a string (checked through hysop's check_instance).
        check_instance(ctype, str)
        node.ctype = ctype
        return node
class TypedExprWrapper(TypedExpr):
    """Attach an OpenCL type to an arbitrary expression.

    The generated code is exactly the printed form of the wrapped ``expr``.
    """

    def __new__(cls, ctype, expr):
        wrapper = super().__new__(cls, ctype, expr)
        wrapper.expr = expr
        return wrapper

    def _ccode(self, printer):
        wrapped = self.expr
        return printer._print(wrapped)
class OpenClConvert(TypedExpr):
    """Explicit OpenCL conversion, printed as ``convert_<ctype>(<expr>)``."""

    def __new__(cls, ctype, expr):
        node = super().__new__(cls, ctype, expr)
        node.expr = expr
        return node

    def _ccode(self, printer):
        return "convert_{}({})".format(self.ctype, printer._print(self.expr))
class OpenClCast(TypedExpr):
    """C-style cast, printed as ``(<ctype>)(<expr>)``."""

    def __new__(cls, ctype, expr):
        node = super().__new__(cls, ctype, expr)
        node.expr = expr
        return node

    def _ccode(self, printer):
        return "({})({})".format(self.ctype, printer._print(self.expr))
class OpenClBool(TypedExpr):
    """Convert a scalar boolean condition to a vector-compatible one.

    A scalar boolean condition is an integer in OpenCL (0 or 1), whereas a
    vector boolean condition has all bits set when true. Negating the
    scalar value maps 1 to -1 (all bits set) and leaves 0 unchanged, so the
    result can be mixed with vector conditions after vectorization.

    The node is typed ``char`` (lowest integer rank) so that any later
    arithmetic promotes it to the type it is combined with.
    """

    def __new__(cls, expr):
        # Report the offending type; ``ctype`` is only bound below.
        assert expr.ctype in ("short", "int", "long"), expr.ctype
        ctype = "char"  # force lowest integer rank (to force promotion later)
        obj = super().__new__(cls, ctype, expr)
        obj.expr = expr
        return obj

    def _ccode(self, printer):
        expr = printer._print(self.expr)
        # Negate the scalar boolean: -1 has all bits set, 0 has no bits set.
        # Pre-promoting with an unsigned cast, e.g. '(u{})({})'.format(...),
        # breaks conditionals if further promotion is needed afterwards.
        return f"(-({expr}))"
class Return(Expr):
    """``return <expr>;`` statement appended directly to the code generator."""

    def __new__(cls, expr):
        obj = super().__new__(cls, expr)
        obj.expr = expr
        return obj

    def _ccode(self, printer):
        expr = printer._print(self.expr)
        # The statement is emitted as a side effect; the sentinel tells the
        # printer that nothing remains to be inlined.
        printer.codegen.append(f"return {expr};")
        return InstructionTermination
class NumericalConstant(TypedExpr):
    """Typed literal whose code is produced by ``printer.typegen.dump``."""

    def __new__(cls, ctype, value):
        const = super().__new__(cls, ctype, value)
        const.value = value
        return const

    def _ccode(self, printer):
        typegen = printer.typegen
        return typegen.dump(self.value)

    @classmethod
    def build(cls, val, typegen):
        """Create a constant typed with ``typegen.dumped_type(val)``."""
        return cls(typegen.dumped_type(val), val)
class IntegerConstant(NumericalConstant):
    """Integer literal; behaves exactly like :class:`NumericalConstant`."""
class FloatingPointConstant(NumericalConstant):
    """Floating point literal; behaves exactly like :class:`NumericalConstant`."""
class ComplexFloatingPointConstant(NumericalConstant):
    """Complex literal printed as ``((<ctype>)(<real>, <imag>))``.

    Real and imaginary parts are each rendered with ``printer.typegen.dump``.
    """

    def _ccode(self, printer):
        dump = printer.typegen.dump
        real_part = dump(self.value.real)
        imag_part = dump(self.value.imag)
        return f"(({self.ctype})({real_part}, {imag_part}))"
class OpenClVariable(TypedExpr):
    """Typed wrapper around a code generator variable.

    The generated code is whatever ``var()`` returns; the printer is unused.
    """

    def __new__(cls, ctype, var, *args):
        node = super().__new__(cls, ctype, var, *args)
        node.var = var
        return node

    def _ccode(self, printer):
        variable = self.var
        return variable()
class OpenClIndexedVariable(OpenClVariable):
    """Indexed access ``var[index]``, possibly gathering several elements.

    When ``index.var.dim`` is available, the node type is widened to a vector
    of ``components(ctype) * dim`` components and ``dim`` elements are
    gathered; otherwise ``dim`` is 1 and a plain subscript is printed.
    """

    def __new__(cls, ctype, var, index):
        try:
            dim = index.var.dim
            n_components = cl_components(ctype)
            ctype = cls.vtype(cl_basetype(ctype), n_components * dim)
        except AttributeError:
            # index carries no dimension information: scalar access
            dim = 1
        node = super().__new__(cls, ctype, var, index)
        node.index = index
        node.dim = dim
        return node

    def _ccode(self, printer):
        buffer_types = (OpenClSymbolicBuffer, OpenClSymbolicNdBuffer)
        if not isinstance(self.var, buffer_types):
            # Non-buffer variables may know how to index themselves; if that
            # fails for any reason, fall back to generic printing below.
            try:
                return self.var[self.index]
            except Exception:
                pass
        var_code = printer._print(self.var)
        if self.dim > 1:
            gathered = []
            for lane in range(self.dim):
                gathered.append(f"{var_code}[{self.index.var[lane]}]")
            return "({})({})".format(self.ctype, ", ".join(gathered))
        return f"{var_code}[{printer._print(self.index)}]"
class OpenClAssignment(TypedExpr):
    """Statement ``<var> <op> <rhs>;`` appended to the code generator.

    ``op`` is inserted verbatim between the printed operands.
    """

    def __new__(cls, ctype, var, op, rhs):
        stmt = super().__new__(cls, ctype, var, op, rhs)
        stmt.var, stmt.op, stmt.rhs = var, op, rhs
        return stmt

    def _ccode(self, printer):
        lhs_code = printer._print(self.var)
        rhs_code = printer._print(self.rhs)
        printer.codegen.append("{} {} {};".format(lhs_code, self.op, rhs_code))
        return InstructionTermination
class FunctionCall(TypedExpr):
    """Call to a function object ``fn`` with keyword arguments ``fn_kwds``.

    The generated code is whatever ``fn(**fn_kwds)`` returns.
    """

    def __new__(cls, ctype, fn, fn_kwds):
        obj = super().__new__(cls, ctype, fn, fn_kwds)
        obj.fn = fn
        obj.fn_kwds = fn_kwds
        return obj

    def _ccode(self, printer):
        return self.fn(**self.fn_kwds)

    def _sympystr(self, printer):
        # NOTE(review): assumes ``fn`` usually exposes a ``name`` attribute;
        # otherwise the function object itself is shown.
        label = getattr(self.fn, "name", self.fn)
        return f"FunctionCall({label})"
class VStore(Expr):
    """Vector store of ``data`` to ``ptr`` at ``offset`` (``n`` components).

    Code is obtained from ``printer.codegen.vstore`` and appended to the
    generator; extra keyword options are forwarded to ``vstore`` unchanged.
    """

    def __new__(cls, ptr, offset, data, n=1, **opts):
        obj = super().__new__(cls, ptr, offset, data, n)
        obj.ptr = ptr
        obj.offset = offset
        obj.data = data
        obj.n = n
        obj.opts = opts
        return obj

    def _ccode(self, printer):
        # NOTE(review): assumes codegen.vstore takes the stored value as ``data``.
        code = printer.codegen.vstore(
            n=self.n, ptr=self.ptr, offset=self.offset, data=self.data, **self.opts
        )
        printer.codegen.append(code)
        return InstructionTermination
class VStoreIf(VStore):
    """Conditional vector store emitted through ``printer.codegen.vstore_if``.

    ``cond`` and ``scalar_cond`` are forwarded as-is to the code generator,
    together with any extra keyword options.
    """

    def __new__(cls, cond, scalar_cond, ptr, offset, data, n, **opts):
        obj = super().__new__(cls, ptr, offset, data, n)
        # Set explicitly so this class does not depend on how the parent
        # stores them.
        obj.offset = offset
        obj.data = data
        obj.cond = cond
        obj.scalar_cond = scalar_cond
        obj.opts = opts
        return obj

    def _ccode(self, printer):
        # vstore_if emits the code itself; its return value is not used.
        printer.codegen.vstore_if(
            cond=self.cond,
            scalar_cond=self.scalar_cond,
            n=self.n,
            ptr=self.ptr,
            offset=self.offset,
            data=self.data,
            **self.opts,
        )
        return InstructionTermination
class VLoad(TypedExpr):
    """Vector load of ``n`` components from ``ptr`` at ``offset``.

    When ``dst`` is truthy, the load is assigned to it via ``dst.affect`` and
    the sentinel ``InstructionTermination`` is returned; otherwise the load
    code from ``printer.codegen.vload`` is returned inline.
    """

    def __new__(cls, ctype, ptr, offset, dst=None, n=1, **opts):
        node = super().__new__(cls, ctype, ptr, offset, dst, n)
        node.ptr, node.offset, node.dst, node.n = ptr, offset, dst, n
        node.opts = opts
        return node

    def _ccode(self, printer):
        codegen = printer.codegen
        load_code = codegen.vload(
            n=self.n, ptr=self.ptr, offset=self.offset, **self.opts
        )
        if not self.dst:
            return load_code
        self.dst.affect(codegen, load_code)
        return InstructionTermination
class VLoadIf(VLoad):
    """Conditional vector load emitted through ``printer.codegen.vload_if``.

    ``cond``, ``scalar_cond``, ``dst`` and ``default_value`` are forwarded
    as-is to the code generator together with any extra keyword options.
    """

    def __new__(cls, cond, scalar_cond, ptr, offset, dst, n, default_value, **opts):
        # VLoad.__new__ expects the C type as its first argument; without it,
        # every argument shifts by one slot. The node is typed after its
        # destination.
        # NOTE(review): assumes ``dst`` exposes its OpenCL type as ``dst.ctype``.
        obj = super().__new__(cls, dst.ctype, ptr, offset, dst, n)
        obj.cond = cond
        obj.scalar_cond = scalar_cond
        obj.default_value = default_value
        obj.opts = opts
        return obj

    def _ccode(self, printer):
        # vload_if emits the code itself; its return value is not used.
        printer.codegen.vload_if(
            cond=self.cond,
            scalar_cond=self.scalar_cond,
            n=self.n,
            ptr=self.ptr,
            offset=self.offset,
            dst=self.dst,
            default_value=self.default_value,
            **self.opts,
        )
        return InstructionTermination
class IfElse(Expr):
    """Conditional block with one branch per condition and an optional else.

    ``all_exprs`` is either a list of branch bodies (one list per condition)
    or a flat list of expressions when exactly one condition is given.
    Every branch is emitted inside ``codegen._if_(cond)``; ``else_exprs``, if
    non-empty, inside ``codegen._else_()``.
    """

    def __new__(cls, conditions, all_exprs, else_exprs=None):
        conditions = to_tuple(conditions)
        branches = to_list(all_exprs)
        if else_exprs is not None:
            else_exprs = to_list(else_exprs)
        assert len(branches) >= 1
        if not isinstance(branches[0], list):
            # A flat list is the body of a single condition.
            assert len(conditions) == 1
            branches = [branches]
        assert len(conditions) == len(branches) >= 1
        node = super().__new__(cls, conditions, branches, else_exprs)
        node.conditions = conditions
        node.all_exprs = branches
        node.else_exprs = else_exprs
        return node

    def _ccode(self, printer):
        codegen = printer.codegen
        emit = printer._print
        for condition, body in zip(self.conditions, self.all_exprs):
            with codegen._if_(condition):
                for statement in body:
                    emit(statement)
        if self.else_exprs:
            with codegen._else_():
                for statement in self.else_exprs:
                    emit(statement)
        return InstructionTermination
class UpdateVars(Expr):
    """Copy ``srcs`` into ``dsts`` and emit the corresponding code.

    Non-pointer destinations are assigned directly (private stores).
    ``__local`` pointer destinations are written with
    ``codegen.multi_vstore_if``, framed by local barriers (local stores).
    ``ghosts`` gives, per (src, dst) pair, the extra offset used for local
    stores.
    """

    def __new__(cls, srcs, dsts, ghosts):
        obj = super().__new__(cls, srcs, dsts, ghosts)
        assert srcs and dsts
        obj.srcs = srcs
        obj.dsts = dsts
        obj.init(srcs, dsts, ghosts)
        return obj

    def init(self, srcs, dsts, ghosts):
        """Split the (src, dst) pairs into private and ``__local`` stores."""
        assert len(srcs) == len(dsts)
        private_stores = ()
        local_stores = ()
        for src, dst, ghost in zip(srcs, dsts, ghosts):
            assert not src.is_ptr
            if dst.is_ptr:
                # Pointer destinations must live in local memory.
                # NOTE(review): attribute name ``storage`` reconstructed; confirm.
                assert dst.storage == "__local"
                local_stores += ((src, dst, ghost),)
            else:
                private_stores += ((src, dst),)
        self.private_stores = private_stores
        self.local_stores = local_stores

    def _ccode(self, printer):
        codegen = printer.codegen
        codegen.jumpline()
        # NOTE(review): ``csc`` is bound to the return value of
        # ``codegen.comment`` but is used below as a symbolic context
        # (is_last_active, full_offset, vectorization, ...); confirm that
        # ``comment`` really returns that object.
        csc = codegen.comment(
            "Updating {} from {}".format(
                ", ".join(x() for x in self.dsts), ", ".join(x() for x in self.srcs)
            )
        )
        if self.local_stores:
            codegen.barrier(_local=True)
        if self.private_stores:
            with codegen._align_() as al:
                for src, dst in self.private_stores:
                    dst.affect(al, init=src, align=True)
        if self.local_stores:
            # Unzip (src, dst, ghost) triplets into three parallel tuples.
            srcs, ptrs, offsets = (tuple(col) for col in zip(*self.local_stores))
            codegen.multi_vstore_if(
                csc.is_last_active,
                lambda i: f"{csc.full_offset}+{i} < {csc.compute_grid_size[0]}",
                csc.vectorization,
                csc.local_offset,
                srcs,
                ptrs,
                extra_offsets=offsets,
                use_short_circuit=csc.use_short_circuit,
                else_cond=csc.is_active,
            )
            codegen.barrier(_local=True)
        return InstructionTermination
class BuiltinFunctionCall(TypedExpr):
    """Call to an OpenCL builtin, printed as ``fname(arg0, arg1, ...)``."""

    def __new__(cls, ctype, fname, *fargs):
        call = super().__new__(cls, ctype, fname, fargs)
        call.fname = fname
        call.fargs = fargs
        return call

    def _ccode(self, printer):
        printed_args = [printer._print(arg) for arg in self.fargs]
        return f"{self.fname}({', '.join(printed_args)})"
class BuiltinFunction:
    """Factory producing ``BuiltinFunctionCall`` nodes for one builtin name.

    ``f(ctype, *args)`` builds ``BuiltinFunctionCall(ctype, f.fname, *args)``.
    """

    def __new__(cls, fname):
        instance = super().__new__(cls)
        instance.fname = fname
        return instance

    def __call__(self, ctype, *args):
        name = self.fname
        return BuiltinFunctionCall(ctype, name, *args)
class OpenClPrinter(C99CodePrinter):
    """Sympy C99 code printer adapted to OpenCL code generation.

    Expression nodes' ``_ccode`` hooks either return inline code or emit
    statements into ``self.codegen`` and return ``InstructionTermination``.
    ``self.typegen`` is used by constant nodes to render literal values.
    """

    _default_settings = {
        "order": None,
        "full_prec": "auto",
        "precision": 15,
        "user_functions": {},
        "human": True,
        "contract": True,
        "dereference": set(),
        "error_on_reserved": False,
        "reserved_word_suffix": "_",
    }

    def __init__(self, typegen, codegen, settings=None, **kwds):
        # Fresh dict per call instead of a shared mutable default.
        settings = {} if settings is None else settings
        super().__init__(settings=settings, **kwds)
        self.typegen = typegen
        self.codegen = codegen

    def _handle_UnevaluatedExpr(self, expr):
        # Override sympy's hook as a no-op: the expression is kept unchanged.
        return expr

    def doprint(self, expr, terminate=True):
        """Print ``expr`` with this printer.

        With ``terminate=False`` the printed code string is returned.
        With ``terminate=True`` (default) printing must yield
        ``InstructionTermination`` (i.e. the code was emitted into
        ``self.codegen``); ``None`` is then returned.

        Raises:
            RuntimeError: if ``terminate`` is true and something else was
                printed.
        """
        res = super().doprint(expr)
        if not terminate:
            return res
        if res != InstructionTermination:
            msg = (
                "OpenClPrinter failed to generate code for the following expression:\n"
            )
            msg += f" {expr}\n"
            msg += f"Returned value was:\n {res}\n"
            raise RuntimeError(msg)
        return None